# from continual_rl.policies.clear.clear_policy_config import ClearPolicyConfig
#
#
# class SanePolicyConfig(ClearPolicyConfig):
#
#     def __init__(self):
#         super().__init__()
#         self.allowed_uncertainty_scale_for_creation = [1.0, 10.0]
#         self.uncertainty_scale = 1.0
#         self.min_steps_before_force_create = 100000
#         self.max_nodes = 6
#         self.fraction_of_nodes_mergeable = 0.75  # Of max_nodes
#         self.create_adds_replay = False
#         self.clear_loss_coeff = 1.0
#         self.merge_by_frame = True
#         self.merge_by_batch = True  # Alternative: merge by average of entire buffer
#         self.uncertainty_scale_in_get_active = 1.0
#         self.merge_batch_scale = 5.0  # How many batches to use when computing the merge metric
#         self.visualize_nodes = False
#         self.keep_larger_reservoir_val_in_merge = True
#         self.creation_pattern = "asymmetric_reset_anchor"
#         self.use_slow_critic = False
#         self.slow_critic_update_cadence = 10000
#         self.only_create_from_active = True
#         self.slow_critic_ema_new_weight = -1.0  # -1 means use equally weighted average
#         self.usage_count_based_merge = False
#         self.train_all = False
#         self.duplicate_optimizer = True
#         self.static_ensemble = False  # Baseline
#         self.map_task_id_to_module = False
#         self.baseline_extended_arch = True
#         self.baseline_includes_uncertainty = True  # Necessary for SANE, overrides IMPALA
#
#     def _load_from_dict_internal(self, config_dict):
#         config = super()._load_from_dict_internal(config_dict)
#         assert int(self.keep_larger_reservoir_val_in_merge) + int(self.usage_count_based_merge) <= 1, "Only one merge strategy should be specified"
#         assert not self.map_task_id_to_module or self.static_ensemble, "map_task_id_to_module requires static_ensemble"
#         return config

from continual_rl.policies.clear.clear_policy_config import ClearPolicyConfig


class SanePolicyConfig(ClearPolicyConfig):
    """Configuration for the SANE policy, built on top of ``ClearPolicyConfig``.

    Keeps the original SANE node/merge settings and adds parameters for a
    task-similarity-based node-creation scheme, node fusion, and distillation.
    Several value-anchor-era settings are still defined but marked deprecated
    (see the end of ``__init__``).
    """

    def __init__(self):
        super().__init__()
        # Original configuration items are retained; new similarity-related ones are added below.
        self.max_nodes = 16
        self.fraction_of_nodes_mergeable = 0.75  # Of max_nodes
        self.create_adds_replay = False
        self.merge_by_frame = True
        self.merge_by_batch = True  # Alternative: merge by average of entire buffer
        self.merge_batch_scale = 5.0  # How many batches to use when computing the merge metric
        self.visualize_nodes = False
        self.keep_larger_reservoir_val_in_merge = True
        self.use_slow_critic = False
        self.slow_critic_update_cadence = 10000
        self.slow_critic_ema_new_weight = -1.0  # -1 means use equally weighted average
        self.usage_count_based_merge = False
        self.train_all = True
        self.duplicate_optimizer = True
        self.static_ensemble = False  # Baseline
        self.map_task_id_to_module = False
        self.baseline_extended_arch = True
        self.baseline_includes_uncertainty = True  # Necessary for SANE, overrides IMPALA

        # New configuration for task-similarity-based node creation.
        self.similarity_threshold = 0.28  # Similarity threshold; a new node is created when similarity is below this
        self.feature_steps = 500  # Number of steps used for feature extraction
        self.feature_dim = 512  # Dimension of the feature vector

        self.clear_loss_coeff = 1.0

        # Node-fusion (merging) parameters.
        self.weighting_exponent = 2.0  # Newly added key parameter (weighting exponent)
        self.freeze_lower_layers = False  # Whether to freeze the lower network layers
        self.noise_scale = 0.1  # Noise strength used when initializing the top layer(s)
        # self.min_affinity_threshold = 0.3
        # self.history_weight = 0.4
        # self.prediction_weight = 0.4
        # self.uncertainty_penalty = 0.2

        # Distillation configuration.
        self.wasserstein_scale = 1.0  # Scale factor for converting Wasserstein distance into similarity
        # Distillation loss weight. Declared as a float (not int 0) so that a loader
        # which casts overrides to the default's type does not truncate fractional
        # weights such as 0.5. NOTE(review): confirm the base config loader casts this way.
        self.distill_alpha = 0.0
        self.distill_neighbors = 2  # Number of distillation neighbors
        self.distill_update_freq = 1000  # Knowledge-graph update frequency

        # Deprecated: original value-anchor-based configuration (kept optionally).
        # These settings are no longer used by the similarity-based implementation.
        self.allowed_uncertainty_scale_for_creation = [1.0, 10.0]  # Deprecated
        self.uncertainty_scale = 1.0  # Deprecated
        self.min_steps_before_force_create = 100000  # Deprecated
        self.uncertainty_scale_in_get_active = 1.0  # Deprecated
        self.creation_pattern = "asymmetric_reset_anchor"  # Deprecated
        self.only_create_from_active = True  # Deprecated

    def _load_from_dict_internal(self, config_dict):
        """Load settings via the parent loader, then validate option combinations.

        Args:
            config_dict: Mapping of configuration overrides, passed unchanged to
                ``ClearPolicyConfig._load_from_dict_internal``.

        Returns:
            Whatever the parent loader returns.

        Raises:
            AssertionError: If both merge strategies are enabled, or if
                ``map_task_id_to_module`` is set without ``static_ensemble``.
                (These are ``assert`` statements, so they are skipped under ``python -O``.)
        """
        config = super()._load_from_dict_internal(config_dict)
        # At most one merge strategy may be active at a time.
        assert int(self.keep_larger_reservoir_val_in_merge) + int(
            self.usage_count_based_merge) <= 1, "Only one merge strategy should be specified"
        assert not self.map_task_id_to_module or self.static_ensemble, "map_task_id_to_module requires static_ensemble"
        return config


